#!/usr/bin/env python
import sys
import requests
import ssl

# Names of the two sources and two workers used by the checks in this file.
SOURCE1_NAME, SOURCE2_NAME = "mysql-01", "mysql-02"
WORKER1_NAME, WORKER2_NAME = "worker1", "worker2"

# Path of the sources collection, shared by every endpoint below.
_SOURCES_PATH = "/api/v1/sources"

# Plain-HTTP endpoints: port 8261, and port 8361 for the "not leader" variant
# (presumably a non-leader member of the same cluster -- confirm with the
# test environment setup).
API_ENDPOINT = "http://127.0.0.1:8261" + _SOURCES_PATH
API_ENDPOINT_NOT_LEADER = "http://127.0.0.1:8361" + _SOURCES_PATH

# HTTPS counterparts of the two endpoints above (same hosts and ports).
API_ENDPOINT_HTTPS = "https://127.0.0.1:8261" + _SOURCES_PATH
API_ENDPOINT_NOT_LEADER_HTTPS = "https://127.0.0.1:8361" + _SOURCES_PATH



def create_source_failed():
    """POST to the sources endpoint without a body and expect HTTP 400."""
    response = requests.post(API_ENDPOINT)
    assert response.status_code == 400
    error_body = response.json()
    print("create_source_failed resp=", error_body)


def create_source1_success():
    """Create the SOURCE1_NAME source (port 3306) and expect HTTP 201."""
    source_config = {
        "case_sensitive": False,
        "enable": True,
        "enable_gtid": False,
        "host": "127.0.0.1",
        "password": "123456",
        "port": 3306,
        "source_name": SOURCE1_NAME,
        "user": "root",
    }
    response = requests.post(API_ENDPOINT, json={"source": source_config})
    # The body is printed before the status check so it is visible on failure.
    created = response.json()
    print("create_source1_success resp=", created)
    assert response.status_code == 201


def create_source2_success():
    """Create the SOURCE2_NAME source (port 3307) and expect HTTP 201.

    The response body is printed before the status assertion so that it is
    available in the output when the assertion fails.
    """
    req = {
        "source": {
            "enable": True,
            "case_sensitive": False,
            "enable_gtid": False,
            "host": "127.0.0.1",
            "password": "123456",
            "port": 3307,
            "source_name": SOURCE2_NAME,
            "user": "root",
        }
    }
    resp = requests.post(url=API_ENDPOINT, json=req)
    # Label the output with this function's own name; the previous label was
    # copied from create_source1_success and misattributed the output.
    print("create_source2_success resp=", resp.json())
    assert resp.status_code == 201

def create_source_success_https(ssl_ca, ssl_cert, ssl_key):
    """Create the SOURCE1_NAME source over HTTPS and expect HTTP 201.

    ssl_ca is passed to requests as ``verify`` and (ssl_cert, ssl_key) as the
    client ``cert`` pair.
    """
    source_config = {
        "case_sensitive": False,
        "enable": True,
        "enable_gtid": False,
        "host": "127.0.0.1",
        "password": "123456",
        "port": 3306,
        "source_name": SOURCE1_NAME,
        "user": "root",
    }
    tls_options = {"verify": ssl_ca, "cert": (ssl_cert, ssl_key)}
    response = requests.post(
        API_ENDPOINT_HTTPS, json={"source": source_config}, **tls_options
    )
    created = response.json()
    print("create_source_success_https resp=", created)
    assert response.status_code == 201

def update_source1_without_password_success():
    """PUT a SOURCE1_NAME config that omits "password" and expect HTTP 200."""
    source_config = {
        "case_sensitive": False,
        "enable": True,
        "enable_gtid": False,
        "host": "127.0.0.1",
        "port": 3306,
        "source_name": SOURCE1_NAME,
        "user": "root",
    }
    target = "/".join((API_ENDPOINT, SOURCE1_NAME))
    response = requests.put(target, json={"source": source_config})
    updated = response.json()
    print("update_source1_without_password_success resp=", updated)
    assert response.status_code == 200

def list_source_success(source_count):
    """List sources and check that "total" equals int(source_count)."""
    response = requests.get(API_ENDPOINT)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_by_openapi_success resp=", listing)
    # source_count may arrive as a string from the command line.
    expected_total = int(source_count)
    assert listing["total"] == expected_total

def list_source_success_https(source_count, ssl_ca, ssl_cert, ssl_key):
    """List sources over HTTPS and check "total" equals int(source_count).

    ssl_ca is passed to requests as ``verify`` and (ssl_cert, ssl_key) as the
    client ``cert`` pair.
    """
    tls_options = {"verify": ssl_ca, "cert": (ssl_cert, ssl_key)}
    response = requests.get(API_ENDPOINT_HTTPS, **tls_options)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_success_https resp=", listing)
    expected_total = int(source_count)
    assert listing["total"] == expected_total

def list_source_with_status_success(source_count, status_count):
    """List sources with ``with_status=true`` and check the counts.

    Asserts that "total" equals int(source_count) and that each of the first
    int(source_count) entries of "data" has a "status_list" of length
    int(status_count).
    """
    response = requests.get(API_ENDPOINT + "?with_status=true")
    assert response.status_code == 200
    listing = response.json()
    print("list_source_with_status_success resp=", listing)
    expected_sources = int(source_count)
    assert listing["total"] == expected_sources
    for position in range(expected_sources):
        status_list = listing["data"][position]["status_list"]
        assert len(status_list) == int(status_count)


def list_source_with_reverse(source_count):
    """List sources via API_ENDPOINT_NOT_LEADER and check "total"."""
    response = requests.get(API_ENDPOINT_NOT_LEADER)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_with_reverse resp=", listing)
    expected_total = int(source_count)
    assert listing["total"] == expected_total

def list_source_with_reverse_https(source_count, ssl_ca, ssl_cert, ssl_key):
    """List sources via API_ENDPOINT_NOT_LEADER_HTTPS and check "total".

    ssl_ca is passed to requests as ``verify`` and (ssl_cert, ssl_key) as the
    client ``cert`` pair.
    """
    tls_options = {"verify": ssl_ca, "cert": (ssl_cert, ssl_key)}
    response = requests.get(API_ENDPOINT_NOT_LEADER_HTTPS, **tls_options)
    assert response.status_code == 200
    listing = response.json()
    print("list_source_with_reverse_https resp=", listing)
    expected_total = int(source_count)
    assert listing["total"] == expected_total

def delete_source_success(source_name):
    """DELETE the named source and expect HTTP 204."""
    target = "/".join((API_ENDPOINT, source_name))
    response = requests.delete(target)
    assert response.status_code == 204
    print("delete_source_success")


def delete_source_with_force_success(source_name):
    """DELETE the named source with ``force=true`` and expect HTTP 204."""
    target = "/".join((API_ENDPOINT, source_name)) + "?force=true"
    response = requests.delete(target)
    assert response.status_code == 204
    print("delete_source_with_force_success")


def delete_source_failed(source_name):
    """DELETE the named source and expect the request to fail with HTTP 400."""
    target = "/".join((API_ENDPOINT, source_name))
    response = requests.delete(target)
    # The error body is printed before the status check.
    error_body = response.json()
    print("delete_source_failed msg=", error_body)
    assert response.status_code == 400


def enable_relay_failed(source_name, worker_name):
    """POST relay/enable for one worker and expect HTTP 400."""
    target = "/".join((API_ENDPOINT, source_name, "relay", "enable"))
    payload = {"worker_name_list": [worker_name]}
    response = requests.post(target, json=payload)
    error_body = response.json()
    print("enable_relay_failed resp=", error_body)
    assert response.status_code == 400


def enable_relay_success(source_name, worker_name):
    """POST relay/enable for one worker of a source and expect HTTP 200.

    Args:
        source_name: name of the source, used in the request URL.
        worker_name: name of the worker, sent as the single entry of
            "worker_name_list".
    """
    url = API_ENDPOINT + "/" + source_name + "/relay/enable"
    req = {
        # Previously worker_name was accepted but never sent, so the request
        # did not target the worker the caller asked for. Send it the same
        # way the other relay helpers in this file do.
        "worker_name_list": [worker_name],
        "purge": {"interval": 3600, "expires": 0, "remain_space": 15},
    }
    resp = requests.post(url=url, json=req)
    assert resp.status_code == 200


def enable_relay_success_with_two_worker(source_name, worker1_name, worker2_name):
    """POST relay/enable for two workers of a source and expect HTTP 200."""
    target = "/".join((API_ENDPOINT, source_name, "relay", "enable"))
    purge_settings = {"interval": 3600, "expires": 0, "remain_space": 15}
    payload = {
        "worker_name_list": [worker1_name, worker2_name],
        "purge": purge_settings,
    }
    response = requests.post(target, json=payload)
    assert response.status_code == 200


def disable_relay_failed(source_name, worker_name):
    """POST relay/disable for one worker and expect HTTP 400."""
    target = "/".join((API_ENDPOINT, source_name, "relay", "disable"))
    payload = {"worker_name_list": [worker_name]}
    response = requests.post(target, json=payload)
    error_body = response.json()
    print("disable_relay_failed resp=", error_body)
    assert response.status_code == 400


def disable_relay_success(source_name, worker_name):
    """POST relay/disable for one worker and expect HTTP 200."""
    target = "/".join((API_ENDPOINT, source_name, "relay", "disable"))
    payload = {"worker_name_list": [worker_name]}
    response = requests.post(target, json=payload)
    assert response.status_code == 200


def purge_relay_success(source_name, relay_binlog_name, relay_dir=""):
    """POST relay/purge for a source and expect HTTP 200.

    The response body is printed only when the status is not 200.
    """
    target = "/".join((API_ENDPOINT, source_name, "relay", "purge"))
    payload = {"relay_binlog_name": relay_binlog_name, "relay_dir": relay_dir}
    response = requests.post(target, json=payload)
    succeeded = response.status_code == 200
    if not succeeded:
        print("purge relay  failed resp=", response.json())
    assert succeeded


def get_source_status_failed(source_name):
    """GET the status of a source and expect HTTP 400."""
    target = "/".join((API_ENDPOINT, source_name, "status"))
    response = requests.get(target)
    error_body = response.json()
    print("get_source_status_failed resp=", error_body)
    assert response.status_code == 400


def get_source_status_success(source_name, status_count=1):
    """GET the status of a source and check "total" equals int(status_count)."""
    target = "/".join((API_ENDPOINT, source_name, "status"))
    response = requests.get(target)
    assert response.status_code == 200
    status = response.json()
    print("get_source_status_success resp=", status)
    # status_count may arrive as a string from the command line.
    assert int(status_count) == status["total"]


def get_source_status_success_with_relay(source_name, source_idx=0):
    """GET the status of a source and require a non-None "relay_status".

    The entry checked is ``data[int(source_idx)]`` of the response body.
    """
    target = "/".join((API_ENDPOINT, source_name, "status"))
    response = requests.get(target)
    assert response.status_code == 200
    status = response.json()
    print("get_source_status_success_with_relay resp=", status)
    entry = status["data"][int(source_idx)]
    assert entry["relay_status"] is not None


def get_source_status_success_no_relay(source_name, source_idx=0):
    """GET the status of a source and require "relay_status" absent or None.

    The entry checked is ``data[int(source_idx)]`` of the response body.
    """
    target = "/".join((API_ENDPOINT, source_name, "status"))
    response = requests.get(target)
    assert response.status_code == 200
    status = response.json()
    print("get_source_status_success_no_relay resp=", status)
    entry = status["data"][int(source_idx)]
    assert entry.get("relay_status") is None


def transfer_source_success(source_name, worker_name):
    """POST a transfer of the source to the given worker and expect HTTP 200."""
    target = "/".join((API_ENDPOINT, source_name, "transfer"))
    response = requests.post(target, json={"worker_name": worker_name})
    assert response.status_code == 200


def get_source_schemas_and_tables_success(source_name, schema_name="", table_name=""):
    """Check that a schema and one of its tables are listed for a source.

    First GETs ``<source>/schemas`` and asserts that schema_name is in the
    returned JSON; then GETs ``<source>/schemas/<schema_name>`` and asserts
    that table_name is in the returned JSON. Both requests must return
    HTTP 200.

    Args:
        source_name: name of the source, used in the request URLs.
        schema_name: schema expected in the schema listing; also selects
            which schema's tables are listed.
        table_name: table expected in the table listing.
    """
    schema_url = API_ENDPOINT + "/" + source_name + "/schemas"
    schema_resp = requests.get(url=schema_url)
    assert schema_resp.status_code == 200
    schema_list = schema_resp.json()
    print("get_source_schemas_and_tables_success schema_resp=", schema_list)
    assert schema_name in schema_list

    # List the tables of the schema that was just verified. The URL used to
    # hard-code "openapi", which ignored schema_name for any other schema.
    table_url = schema_url + "/" + schema_name
    table_resp = requests.get(url=table_url)
    # Check the status here too, so an error response is reported as a status
    # failure rather than as a confusing membership failure below.
    assert table_resp.status_code == 200
    table_list = table_resp.json()
    print("get_source_schemas_and_tables_success table_resp=", table_list)
    assert table_name in table_list


if __name__ == "__main__":
    # Command-line dispatch: the first argument names the check to run and
    # the remaining arguments are passed to it positionally, as strings.
    FUNC_MAP = {
        "create_source_failed": create_source_failed,
        "create_source1_success": create_source1_success,
        "create_source2_success": create_source2_success,
        "create_source_success_https": create_source_success_https,
        "update_source1_without_password_success": update_source1_without_password_success,
        "list_source_success": list_source_success,
        "list_source_success_https": list_source_success_https,
        "list_source_with_reverse_https": list_source_with_reverse_https,
        "list_source_with_reverse": list_source_with_reverse,
        "list_source_with_status_success": list_source_with_status_success,
        "delete_source_failed": delete_source_failed,
        "delete_source_success": delete_source_success,
        "delete_source_with_force_success": delete_source_with_force_success,
        "enable_relay_failed": enable_relay_failed,
        "enable_relay_success": enable_relay_success,
        "enable_relay_success_with_two_worker": enable_relay_success_with_two_worker,
        "disable_relay_failed": disable_relay_failed,
        "disable_relay_success": disable_relay_success,
        "purge_relay_success": purge_relay_success,
        "get_source_status_failed": get_source_status_failed,
        "get_source_status_success": get_source_status_success,
        "get_source_status_success_with_relay": get_source_status_success_with_relay,
        "get_source_status_success_no_relay": get_source_status_success_no_relay,
        "transfer_source_success": transfer_source_success,
        "get_source_schemas_and_tables_success": get_source_schemas_and_tables_success,
    }

    # Validate before indexing: the old code read sys.argv[1] first, so its
    # length check was always true and a missing or unknown name surfaced as
    # a bare IndexError/KeyError traceback. sys.exit with a string writes it
    # to stderr and exits with status 1.
    if len(sys.argv) < 2:
        sys.exit(
            "usage: %s <function> [args...]\navailable functions: %s"
            % (sys.argv[0], ", ".join(sorted(FUNC_MAP)))
        )
    func_name = sys.argv[1]
    if func_name not in FUNC_MAP:
        sys.exit(
            "unknown function %r\navailable functions: %s"
            % (func_name, ", ".join(sorted(FUNC_MAP)))
        )

    # An empty argument list unpacks to a plain no-argument call.
    FUNC_MAP[func_name](*sys.argv[2:])
